[MXNET-753] Fallback when using non-MKLDNN supported operators #12019
Conversation
@azai91 Very nice improvement! By the way, I see you only add the attr to conv. In general, we need to add it for all MKLDNN ops, right? @ZhennanQin @TaoLv for further comments.
@azai91 Good improvement before the subgraph feature is ready. I'm not familiar with the engine part; can this change cover all execution scenarios, i.e. any combination of NaiveEngine, ThreadedEngine, symbolic, Gluon, hybridized Gluon, dynamic memory allocation, and static memory allocation?
@@ -351,6 +351,13 @@ static inline void InvalidateOutputs(const std::vector<NDArray> &arrs,
  }
}

static inline std::vector<NDArray> InvalidateInputs(const std::vector<NDArray> &arrs) {
The name is confusing; it suggests that the inputs are mutated.
src/executor/attach_op_execs_pass.cc
Outdated
@@ -226,6 +228,11 @@ class FComputeExExecutor : public OpExecutor {
    op_ctx.run_ctx = rctx;
#if MXNET_USE_MKLDNN == 1
    InvalidateOutputs(out_array, req);
    const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
    if (!is_mkldnn.get(attrs_.op, false)) {
      fcompute_(attrs_, op_ctx, InvalidateInputs(in_array), req, out_array);
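The attribute-based dispatch this hunk adds can be sketched in isolation. This is a minimal, self-contained illustration, not MXNet's actual NNVM attribute API: an operator whose `TIsMKLDNN` flag is absent or false is routed to the fallback path.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Toy stand-in for the per-op attribute map queried via Op::GetAttr.
struct OpAttrs {
  std::unordered_map<std::string, bool> flags;
  // Mirrors the `is_mkldnn.get(attrs_.op, false)` lookup: return the
  // flag if registered, otherwise the supplied default.
  bool Get(const std::string& key, bool dflt) const {
    auto it = flags.find(key);
    return it == flags.end() ? dflt : it->second;
  }
};

// Pick the execution path the way the executor does: flagged ops take the
// MKLDNN path, everything else falls back to plain FCompute.
std::string Dispatch(const OpAttrs& attrs) {
  return attrs.Get("TIsMKLDNN", false) ? "mkldnn" : "fallback";
}
```

The default of `false` is what makes the scheme safe for operators that never registered the attribute: they silently get the fallback path.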
Is it thread-safe to modify the inputs?
nvm
Force-pushed from ec1d179 to 3bc9d6d
tests/python/mkl/test_mkldnn.py
Outdated
data = mx.symbol.Variable('data')
conv = mx.sym.Convolution(data=data, kernel=(5, 5), pad=(1, 1), stride=(1, 1), num_filter=8, name="conv", no_bias=True)
mlp = mx.symbol.Custom(name='custom', data=conv, op_type='custom')
nit: rename mlp to custom or custom_op
@@ -299,7 +299,11 @@ class UnaryOp : public OpBase {
      }
      break;
    case kWriteInplace:
      // cannot check if ptrs are the same for MKLDNN because we may have
      // created copies of input when reordering. WriteInPlace will still write to the original array
#if MXNET_USE_MKLDNN != 1
change to #if MXNET_USE_MKLDNN == 0
src/executor/attach_op_execs_pass.cc
Outdated
@@ -40,6 +40,11 @@ const OperatorProperty* OpPropGetOpProperty(const NodeAttrs& attrs);

namespace exec {

class MKLDNNOpExecutor : public OpExecutor {
 protected:
  std::vector<NDArray> in_array_fallback;
Why do we need an extra copy rather than just editing the in_array directly?
There is a race condition if we attempt to reorder the read_var in place: other operators may be reading from it at the same time, since inputs are expected to be read-only constants.
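The copy-instead-of-mutate design being defended here can be sketched as follows. `Array`, the `default_layout` field, and the body of `CreateDefaultInputs` are illustrative stand-ins, not MXNet's NDArray or MKLDNN reorder machinery; the point is only that shared read-only inputs are never touched.

```cpp
#include <cassert>
#include <vector>

// Toy array: a payload plus a flag standing in for "is in default layout".
struct Array {
  std::vector<float> data;
  bool default_layout = false;
};

// Produce default-layout copies of the inputs. The originals are taken by
// const reference and never mutated, so concurrent readers are safe.
std::vector<Array> CreateDefaultInputs(const std::vector<Array>& arrs) {
  std::vector<Array> out;
  out.reserve(arrs.size());
  for (const auto& a : arrs) {
    Array copy = a;              // private copy, safe to transform
    copy.default_layout = true;  // stand-in for an MKLDNN->default reorder
    out.push_back(copy);
  }
  return out;
}
```

An in-place reorder would be a write to memory other threads may be reading concurrently; copying trades a little memory for freedom from that race.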
If that's the case, why do we even bother to create a subclass with this extra data member?
@@ -182,6 +182,7 @@ The following activation functions are supported: | |||
}) | |||
.set_attr<FCompute>("FCompute<cpu>", ActivationCompute<cpu>) | |||
#if MXNET_USE_MKLDNN == 1 | |||
.set_attr<bool>("TIsMKLDNN", true) |
Is there a more generic way to add this instead of doing it for every operator? What if we later add new operators? Should we document this somewhere?
It only needs to be added to MKLDNN operators. This fix is a temporary solution while we get the subgraph feature implemented. We weighed the pros and cons of releasing a stable MKLDNN build with this hack now versus waiting another month for subgraph to be introduced (possibly with its own bugs), and decided to use this short-term solution.
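One more generic option, sketched with a toy registry (`OpReg` and `RegisterMKLDNNOp` are hypothetical helpers, not MXNet's actual NNVM registration API): funnel every MKLDNN operator registration through a single helper that sets the flag, so a newly added operator cannot forget it.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Toy operator registry entry supporting chained attribute setting,
// loosely modeled on NNVM's fluent set_attr style.
struct OpReg {
  std::unordered_map<std::string, bool> attrs;
  OpReg& set_attr(const std::string& key, bool v) {
    attrs[key] = v;
    return *this;
  }
};

// Single chokepoint for MKLDNN operators: the flag is applied here once,
// instead of being repeated in every operator's registration block.
OpReg& RegisterMKLDNNOp(OpReg& reg) {
  return reg.set_attr("TIsMKLDNN", true);
}
```

With a chokepoint like this, forgetting the attribute on a new MKLDNN op becomes a structural impossibility rather than a code-review catch.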
Please mark this TODO and create a JIRA ticket to remove this later after MKLDNN support is released.
added TODO to opexecutor
src/executor/attach_op_execs_pass.cc
Outdated
 public:
  void Run(RunContext rctx, bool is_gpu) override {
    op_ctx.run_ctx = rctx;
#if MXNET_USE_MKLDNN == 1
    InvalidateOutputs(out_array, req);
    in_array_fallback = CreateDefaultInputs(in_array);
Can we just define the array here instead of creating an extra data member in the class?
src/executor/attach_op_execs_pass.cc
Outdated
@@ -153,12 +158,15 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {


// stateful compute_ex executor
-class StatefulComputeExExecutor : public OpExecutor {
+class StatefulComputeExExecutor : public MKLDNNOpExecutor {
Making this class derive from MKLDNNOpExecutor when the base is only valid with MXNET_USE_MKLDNN == 1 violates the is-a relationship of class inheritance.
src/executor/exec_pass.h
Outdated
@@ -86,6 +86,9 @@ class OpExecutor {
  virtual OpStatePtr state() const {
    return OpStatePtr();
  }

 protected:
  std::vector<NDArray> in_array_fallback;
Do we really need this as a class data member? Can we just declare it as a local variable where it is created? I don't see it used anywhere else in the flow.
src/executor/attach_op_execs_pass.cc
Outdated
@@ -159,6 +159,9 @@ class StatefulComputeExExecutor : public OpExecutor {
    op_ctx.run_ctx = rctx;
#if MXNET_USE_MKLDNN == 1
    InvalidateOutputs(out_array, req);
    in_array_fallback = CreateDefaultInputs(in_array);
Can we just declare the std::vector in_array_fallback here?
we need the fallback arrays to stay in memory or else we segfault.
I cannot see why that is the case, but creating a data member is not a good solution to this. I am not very comfortable introducing a new data member to this class without much justification. If this PR is only a temporary workaround I am okay to approve, but please create a JIRA ticket to remove this hack later.
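The lifetime argument behind the data member can be sketched like this, assuming (as the author argues) that the engine may execute the kernel asynchronously after Run() returns. `Executor`, `pending`, and the int arrays are illustrative stand-ins, not MXNet's engine types: if `in_array_fallback` were a local in Run(), the pointer captured by the deferred "kernel" would dangle; as a member it lives as long as the executor.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

struct Executor {
  std::vector<int> in_array_fallback;  // member: outlives Run()
  std::function<int()> pending;        // stand-in for the async engine queue

  void Run(const std::vector<int>& inputs) {
    in_array_fallback = inputs;               // copy kept alive by *this
    const int* p = in_array_fallback.data();  // raw pointer handed to kernel
    std::size_t n = in_array_fallback.size();
    // The "kernel" runs later, after Run() has returned; it only stays
    // valid because the pointed-to storage is owned by the member.
    pending = [p, n] {
      int sum = 0;
      for (std::size_t i = 0; i < n; ++i) sum += p[i];
      return sum;
    };
  }
};
```

A cleaner fix than a member on the base class might be to capture the copies by value inside the closure pushed to the engine, but as the thread notes, this whole path is a stopgap until subgraph lands.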
Force-pushed from 59f18a5 to c356e6c
src/executor/exec_pass.h
Outdated
@@ -86,6 +86,9 @@ class OpExecutor {
  virtual OpStatePtr state() const {
    return OpStatePtr();
  }

 protected:
  std::vector<NDArray> in_array_fallback;
Please make sure to add TODO with JIRA number to remove this hack.
@@ -356,6 +356,17 @@ static inline void InvalidateOutputs(const std::vector<NDArray> &arrs,
  }
}

static inline std::vector<NDArray> CreateDefaultInputs(const std::vector<NDArray> &arrs) {
Also add TODO to remove this unnecessary function when final solution is implemented.
LGTM (given that this is a temporary workaround before the subgraph optimization is integrated and the related JIRA ticket has been assigned).
@@ -356,6 +356,18 @@ static inline void InvalidateOutputs(const std::vector<NDArray> &arrs,
  }
}

// TODO(alexzai): (MXNET-856) Remove helper function after subgraph feature added
static inline std::vector<NDArray> CreateDefaultInputs(const std::vector<NDArray> &arrs) {
Can we pass a pointer to a vector as an argument instead of returning a vector?
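The reviewer's suggestion could look roughly like this; the element type and the body are illustrative stand-ins (the real helper copies NDArrays and reorders MKLDNN layouts), but the output-pointer signature is the shape being asked for.

```cpp
#include <cassert>
#include <vector>

// Output-parameter variant of the helper: the caller owns the destination
// vector, so nothing is returned by value. In practice C++11 move semantics
// make the return-by-value version cheap too; the pointer style mainly
// matches the codebase's existing convention for output parameters.
static void CreateDefaultInputs(const std::vector<int>& arrs,
                                std::vector<int>* out) {
  out->clear();
  out->reserve(arrs.size());
  for (int a : arrs) out->push_back(a);  // stand-in for the reorder/copy
}
```

This also lets the caller reuse one destination vector across iterations instead of constructing a fresh one per call.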
Force-pushed from fe7f954 to da69adb
Force-pushed from bea0609 to f3e55e9
…e#12019)
* add fallback test
* wait to read throws error
* add TIsMKLDNN attr
* invalidate inputs if fcomputeex unsupported
* keep ptr to newly created default arrays
* add flag to all mkldnn operators
* update method name to CreateDefaultInputs
* remove dup attrs
* create new instance var to store copy
* only reorder if mkldnn
* add mkldnn flag to batch norm
* do not check input / output ptr for mkldnn as copied is made
* fix lint
* update macro
* update custom update name
* add todo for fallback
* fix lint
* rename opexecutor name
* add fallback to opexecutor class
* fix lint
* add todos
* create fallback arrays in place
* revert in place diff
* create copy of arrays for fallback
* empty array
Description
- Convert all MKLDNN special-format arrays to the default format when they are consumed by a non-MKLDNN operator.
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments